#!/bin/bash
#
# run without epx:
# bash train_deepseekv2_single_gpu.sh --dataset_dir /root/workspace/llama2_dataset --data_format fp8
#
# run with epx:
# RUST_LOG=debug, cargo run --package epx-ccp -- -p 9008
# bash train_deepseekv2_single_gpu.sh --dataset_dir /root/workspace/llama2_dataset --data_format fp8 -u -m --ccp_port 9008
# bash train_deepseekv2_single_gpu.sh --dataset_dir /root/workspace/llama2_dataset --data_format fp8 -u -m --ccp_port 9008 --master_port 23456 -d 2


# 默认参数
DATA_FORMAT="fp8"    # default fp8
DATASET_DIR=""       # required
USE_EPX=0            # fault torerance
EPX_MASTER_PROCESS=0 # fault torerance master process

# print usage
usage() {
    echo "Usage: $0 [OPTIONS]"
    echo "Options:"
    echo "  --data_format      Set data format (default: fp8)."
    echo "  --dataset_dir      Set dataset directory(required)."
    echo "  --root_path        Set epx root path (default: /root/workspace)."
    echo "  -d, --device_id    Set device id (default: 0). only for single GPU training."
    echo "  -u, --use_epx      Set use_epx (default: 0)."
    echo "  -m, --master_proc  Set current process to be master process (default: 0)."
    echo "  --master_addr      Set master_addr (default: host address)."
    echo "  --master_port      Set master port (default: 12345)."
    echo "  --ccp_port         Set ccp_port (default: 9009)."
    echo "  --store_port       Set store_port (default: 45678)."
    echo "  -h, --help         Show this help message and exit."
    echo "  ============= Example: ============="
    echo "  $0 --dataset_dir /root/workspace/llama2_dataset --data_format fp8 # run without epx"
    echo "  $0 --dataset_dir /root/workspace/llama2_dataset --data_format fp8 \
--use_epx --master_proc --ccp_port 9008 # run with epx"
    echo "  $0 -h # Show this help message."
}

setup_epx_env() {
    HOST_ADDR=$(ip addr show bond0 | grep -oP 'inet \K[\d.]+' | head -1)
    MASTER_ADDR=${MASER_ADDR:-$HOST_ADDR} # set master addr default to host addr

    echo "EPX HOST_ADDR: $HOST_ADDR"
    echo "EPX MASTER_ADDR: $MASTER_ADDR"

    export EPX_CCP_ADDR="$MASTER_ADDR"
    export EPX_CCP_PORT=${EPX_CCP_PORT:-9009}
    export EPX_STORE_ADDR="$MASTER_ADDR"
    export EPX_STORE_PORT=${EPX_STORE_PORT:-45678}
    export EPX_LCP_ADDR="$HOST_ADDR"
    export EPX_SESSION="a0515990-ffbf-11ef-8a45-6fde377b8f7a"
    export EPX_GROUP_RANK=1

    export USE_EPX=$USE_EPX

    ROOT_PATH=${ROOT_PATH:-"/root/workspace"}
    EPX_PATH=$ROOT_PATH/epx
    EPX_STORE_PATH="${EPX_PATH}/epx-py/examples/epx_store.py"

    export PYTHONPATH=${EPX_PATH}/epx-py/python:$PYTHONPATH
    export EPX_LCP_BIN="$EPX_PATH/target/debug/epx-lcp"

    if [ "$EPX_MASTER_PROCESS" -ne 0 ]; then
        echo "Starting epx-store... on port $EPX_STORE_PORT"
        python  $EPX_STORE_PATH --addr "$HOST_ADDR" &
        STORE_PID=$!
    fi
}

# parse command-line arguments
while [[ $# -gt 0 ]]; do
    case "$1" in
        --data_format)
            DATA_FORMAT="$2"
            shift
            ;;
        --dataset_dir)
            DATASET_DIR="$2"
            shift
            ;;
        --root_path)
            ROOT_PATH="$2"
            shift
            ;;
        -d|--device_id)
            DEVICE_ID="$2"
            shift
            ;;
        -u|--use_epx)
            USE_EPX=1
            ;;
        -m|--master_proc)
            EPX_MASTER_PROCESS=1
            ;;
        --master_addr)
            MASTER_ADDR="$2"
            shift
            ;;
        --master_port)
            MASTER_PORT="$2"
            shift
            ;;
        --ccp_port)
            EPX_CCP_PORT="$2"
            shift
            ;;
        --store_port)
            EPX_STORE_PORT="$2"
            shift
            ;;
        -h|--help)
            usage
            exit 0
            ;;
        *)
            echo "Unknown option: $1"
            exit 1
            ;;
    esac
    shift
done

if [[ "$USE_EPX" -ne 0 ]]; then
    setup_epx_env
fi

# arguments check
if [[ -z "$DATASET_DIR" ]]; then
    echo "Error: --dataset_dir is required"
    exit 1
fi

# validate data format
if [[ "$DATA_FORMAT" != "fp8" && "$DATA_FORMAT" != "bf16" ]]; then
    echo "Error: --data_format must be 'fp8' or 'bf16'"
    exit 1
fi

# generate unique ID
CURRENT_TIME=$(date "+%Y%m%d_%H%M%S")
mkdir -p "./output/$CURRENT_TIME"

# training parameters
TP_SIZE=1
PP_SIZE=1
EP_SIZE=1
WORLD_SIZE=1
MICRO_BATCH_SIZE=1
NUM_MICROBATCHES=1
(( DP_SIZE = WORLD_SIZE / (TP_SIZE * PP_SIZE) ))
(( GLOBAL_BATCH_SIZE = MICRO_BATCH_SIZE * NUM_MICROBATCHES * DP_SIZE ))
export GPUS_PER_NODE=1
export MOE_NUM_EXPERTS=20
export MOE_ROUTER_GROUP_TOPK=1
export DEVICE_ID=${DEVICE_ID:-0}
export MUSA_VISIBLE_DEVICES=${MUSA_VISIBLE_DEVICES:-"0,1,2,3,4,5,6,7"}

# generate hostfile
ip a | grep -oP 'inet \K[\d.]+' | grep -v '^127\.' | head -1 > hostfile

# set env variable
WORK_HOME="$PWD"
PATCH_HOME="$PWD/../.."
EXPNAME="tp${TP_SIZE}_pp${PP_SIZE}_dp${DP_SIZE}_mbs${MICRO_BATCH_SIZE}_numbs${NUM_MICROBATCHES}_gbs${GLOBAL_BATCH_SIZE}_gpus${WORLD_SIZE}_${DATA_FORMAT}"
DATA_PATH="${DATASET_DIR}/llama_00_text_document"
HOSTFILE="./hostfile"
LOG_FILE="./output/$CURRENT_TIME/$EXPNAME.log"
TOKENIZED_MODEL="${DATASET_DIR}/tokenizer.model"
SCRIPT_FILE="./deepseek-v2-lite/run_pretrain_deepseekv2_musa.sh"
RDZV_ID="$CURRENT_TIME"
MASTER_PORT=${MASTER_PORT:-12345}

# Precision-related configuration
if [[ "$DATA_FORMAT" == "bf16" ]]; then
    # remove FP8 parameters
    sed -i '/--fp8-format hybrid/d; /--fp8-param-gather/d' "$SCRIPT_FILE"

    # add no-gradient-accumulation-fusion parameter
    sed -i '/no-gradient-accumulation-fusion/c\    --no-gradient-accumulation-fusion' "$SCRIPT_FILE"

    echo "Enabled BF16 mode with recompute optimizations"
fi

# training cmd
cmd="bash -c 'cd $WORK_HOME && \
     bash $SCRIPT_FILE $WORK_HOME $PATCH_HOME $EXPNAME $HOSTFILE \"$DATA_PATH\" \
     $TP_SIZE $PP_SIZE $EP_SIZE \
     $MICRO_BATCH_SIZE $GLOBAL_BATCH_SIZE $TOKENIZED_MODEL $RDZV_ID $MASTER_PORT'"

echo "=== Training Configuration ==="
echo "Dataset dir: $DATASET_DIR"
echo "Data format: $DATA_FORMAT"
echo "Hostfile: $(cat hostfile)"
echo "Global batch size: $GLOBAL_BATCH_SIZE"
echo "Command:"
echo "$cmd"

eval "$cmd" &
EPX_PID=$!
wait $EPX_PID

if [ "$USE_EPX" -ne 0 ] && [ "$EPX_MASTER_PROCESS" -ne 0 ]; then
    kill -9 $STORE_PID
    pkill -f $EPX_STORE_PATH
    pkill -f $EPX_STORE_PATH
fi
